#!/usr/bin/env python3
import click
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.np.baselines import LinearFeatureBaseline
from garage.sampler import LocalSampler
from garage.torch.algos import SHARP
from garage.torch.optimizers import OptimizerWrapper
from garage.torch.optimizers.SHARP_optimizer import SHARPOptimizer
from garage.torch.policies import SoftmaxMLPPolicy
from garage.trainer import Trainer

import torch

# Hyperparameters for SHARPOptimizer. They are also embedded in the
# experiment's log directory name, so each (a, b) setting logs separately.
a, b = 0.5, 3


@click.command()
@click.option('--seed', default=24)
@wrap_experiment(
    log_dir="/root/Data/aaai/cartpole-SHARP-a={}-b={}".format(a, b),
    archive_launch_repo=False,
)
def sharp_cartpole(ctxt=None, seed=1):
    """Train SHARP on CartPole-v1 with a two-hidden-layer softmax MLP policy.

    Args:
        ctxt: Experiment context supplied by ``wrap_experiment``; it is
            used to construct the garage ``Trainer``.
        seed (int): Value passed to ``set_seed`` before any model is built.
            NOTE(review): the click option defaults to 24 while the Python
            default is 1, so CLI launches and direct calls use different
            seeds unless one is given explicitly.

    NOTE(review): ``log_dir`` is a hard-coded absolute path; running on
    another machine requires editing it.
    """
    # Seed before any environment or model is built so that whatever
    # randomness they draw comes from the seeded generators.
    set_seed(seed)
    trainer = Trainer(ctxt)

    cartpole = GymEnv('CartPole-v1')

    softmax_policy = SoftmaxMLPPolicy(
        cartpole.spec,
        hidden_sizes=[64, 64],
        hidden_nonlinearity=torch.nn.ReLU,
        output_nonlinearity=None,
    )
    baseline = LinearFeatureBaseline(env_spec=cartpole.spec)
    local_sampler = LocalSampler(
        agents=softmax_policy,
        envs=cartpole,
        max_episode_length=200,
    )

    # The policy optimizer is a SHARPOptimizer configured with the
    # module-level hyperparameters a and b.
    sharp_optimizer = OptimizerWrapper(
        (SHARPOptimizer, dict(a=a, b=b)),
        softmax_policy,
    )

    algorithm = SHARP(
        env_spec=cartpole.spec,
        policy=softmax_policy,
        value_function=baseline,
        sampler=local_sampler,
        discount=0.99,
        center_adv=False,
        policy_optimizer=sharp_optimizer,
        neural_baseline=False,
    )

    trainer.setup(algorithm, cartpole)
    # 100 epochs, each collecting a batch of 10000 environment steps.
    trainer.train(n_epochs=100, batch_size=10000)


# Launch only when executed as a script, so that importing this module
# (e.g. from a test or another launcher) does not start a training run.
if __name__ == '__main__':
    sharp_cartpole()
